GRU

将多层门控循环单元 (GRU) RNN 应用于输入序列。

GRU 网络模型中有两个门:更新门和重置门。将两个连续的时间节点表示为 \(t - 1\) 和 \(t\)。给定一个在时刻 \(t\) 的输入 \(x_t\) 和一个隐藏状态 \(h_{t-1}\),在时刻 \(t\) 的更新门和重置门使用门控制机制计算。更新门 \(z_t\) 用于控制前一时刻的状态信息被带入到当前状态中的程度,重置门 \(r_t\) 控制前一状态有多少信息被写入到当前候选集 \(n_t\)。

对于输入序列中的每个元素,每一层计算以下函数:

\begin{align*} r_t &= \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{(t-1)} + b_{hr}) \\ z_t &= \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{(t-1)} + b_{hz}) \\ n_t &= \tanh(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{(t-1)} + b_{hn})) \\ h_t &= (1-z_t) \odot n_t + z_t \odot h_{(t-1)} \end{align*}

其中 \(\sigma\) 是 sigmoid 激活函数,\(\odot\) 是 Hadamard 积(逐元素乘积)。\(W, b\) 是公式中输出和输入之间的可学习权重。例如,\(W_{ir}, b_{ir}\) 是用于将输入 \(x_t\) 转换为 \(r_t\) 的权重和偏置。

注意,本算子中候选门 \(n_t\) 的计算与原始论文和 MindSpore 框架略有不同。在原始实现中,\(r_t\) 和上一隐藏状态 \(h_{(t-1)}\) 之间的 Hadamard 积 (\(\odot\)) 在与权重矩阵 \(W\) 相乘和加上偏置之前进行:

\[n_t = \tanh(W_{in}x_t + b_{in} + W_{hn}(r_t \odot h_{(t-1)}) + b_{hn})\]

本算子采用 PyTorch 实现方式,是在 \(W_{hn}h_{(t-1)}\) 之后完成的:

\[n_t = \tanh(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{(t-1)} + b_{hn}))\]

输入:
  • input - 输入数据的地址。

  • weight_g - 可学习的输入-隐藏权重的地址。

  • weight_r - 可学习的隐藏-隐藏权重的地址。

  • input_bias - 可学习的输入-隐藏偏置的地址。

  • state_bias - 可学习的隐藏-隐藏偏置的地址。

  • hidden_state - 初始隐藏状态的地址。

  • buffer - 用于存储中间计算结果。

  • gru_param - 算子计算所需参数的结构体。其各成员见下述。

  • core_mask - 核掩码。

GruParameter定义:

 1typedef struct GruParameter {
 2    int input_size_; // 输入input中预期特征的数量
 3    int hidden_size_; // 隐藏状态h中的特征数量
 4    int seq_len_; // 输入batch中每个序列的长度
 5    int batch_; // 总批次数
 6    int output_step_; // 每次循环中output步长
 7    int bidirectional_; // 是否为双向GRU
 8    int input_row_align_; // 输入行对齐值
 9    int input_col_align_; // 输入列对齐值
10    int state_row_align_; // 隐藏状态行对齐值
11    int state_col_align_; // 隐藏状态列对齐值
12    int check_seq_len_; // 进行计算的序列长度
13} GruParameter;
输出:
  • output - 输出地址。

  • hidden_state - 最终的隐藏状态。

支持平台:

FT78NE、MT7004

备注

  • FT78NE 支持int8, fp32

  • MT7004 支持fp16, fp32

共享存储版本:

void i8_Gru_s(int8_t *output, int8_t *input, int8_t *weight_g, int8_t *weight_r, int8_t *input_bias, int8_t *state_bias, int8_t *hidden_state, int8_t *buffer[4], GruParameter *gru_param, int core_mask)
void hp_Gru_s(half *output, half *input, half *weight_g, half *weight_r, half *input_bias, half *state_bias, half *hidden_state, half *buffer[4], GruParameter *gru_param, int core_mask);
void fp_Gru_s(float *output, float *input, float *weight_g, float *weight_r, float *input_bias, float *state_bias, float *hidden_state, float *buffer[4], GruParameter *gru_param, int core_mask);

C调用示例:

 1void TestGruSMCFp32(int check_seq_len, int seq_len, int batch_size, int input_size, int bidirectional, int hidden_size, int core_mask) {
 2    int core_id = get_core_id();
 3    int logic_core_id = GetLogicCoreId(core_mask, core_id);
 4    int core_num = GetCoreNum(core_mask);
 5    float *output = (void*)0x88000000;
 6    float *input = (void*)0x88100000;
 7    float *weight_g = (void*)0x88200000;
 8    float *weight_r = (void*)0x88300000;
 9    float *input_bias = (void*)0x88400000;
10    float *state_bias = (void*)0x88500000;
11    float *hidden_state = (void*)0x88600000;
12    float** buffer = (float**)0x88700000;
13    float *output_hidden_state = (void*)0x88800000;
14    GruParameter* param = (GruParameter*)0x88900000;
15    int hidden_state_batch = 1;
16    int num_directions = 1;
17    if (bidirectional) {
18        hidden_state_batch = hidden_state_batch * 2;
19        num_directions = num_directions * 2;
20    }
21    int input_col_align = hidden_size;
22    int state_col_align = hidden_size;
23    if (logic_core_id == 0) {
24        memcpy(output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
25        memcpy(check_output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
26        buffer[0] = (void*)0x88A00000;
27        buffer[1] = (void*)0x88B00000;
28        buffer[2] = (void*)0x88C00000;
29        buffer[3] = (void*)0x88D00000;
30        param->batch_ = batch_size;
31        param->bidirectional_ = bidirectional;
32        param->hidden_size_ = hidden_size;
33        param->input_col_align_ = input_col_align;
34        param->input_size_ = input_size;
35        param->output_step_ = batch_size * hidden_size * num_directions;
36        param->seq_len_ = seq_len;
37        param->state_col_align_ = state_col_align;
38        param->check_seq_len_ = check_seq_len;
39    }
40    sys_bar(0, core_num); // 初始化参数完成后进行同步
41    fp_Gru_s(output, input, weight_g, weight_r, input_bias, state_bias, output_hidden_state, buffer, param, core_mask);
42}
43
// Example entry point: launch the shared-memory GRU test on four cores.
// NOTE: the original listing used the non-standard `void main()`; hosted C
// requires `int main` (C11 5.1.2.2.1).
int main(void) {
    int check_seq_len = 2;
    int seq_len = 2;
    int batch_size = 2;
    int input_size = 2;
    int bidirectional = 0;
    int hidden_size = 2;
    int core_mask = 0b1111; // shared-memory version runs on all four cores
    TestGruSMCFp32(check_seq_len, seq_len, batch_size, input_size, bidirectional, hidden_size, core_mask);
    return 0;
}

私有存储版本:

void i8_Gru_p(int8_t *output, int8_t *input, int8_t *weight_g, int8_t *weight_r, int8_t *input_bias, int8_t *state_bias, int8_t *hidden_state, int8_t *buffer[4], GruParameter *gru_param, int core_mask)
void hp_Gru_p(half *output, half *input, half *weight_g, half *weight_r, half *input_bias, half *state_bias, half *hidden_state, half *buffer[4], GruParameter *gru_param, int core_mask);
void fp_Gru_p(float *output, float *input, float *weight_g, float *weight_r, float *input_bias, float *state_bias, float *hidden_state, float *buffer[4], GruParameter *gru_param, int core_mask);

C调用示例:

 1void TestGruL2Fp32(int check_seq_len, int seq_len, int batch_size, int input_size, int bidirectional, int hidden_size, int core_mask) {
 2    float *output = (void*)0x10000000; // 私有存储版本地址设置在AM内
 3    float *input = (void*)0x10004000;
 4    float *weight_g = (void*)0x10008000;
 5    float *weight_r = (void*)0x1000C000;
 6    float *input_bias = (void*)0x10010000;
 7    float *state_bias = (void*)0x10014000;
 8    float *hidden_state = (void*)0x10018000;
 9    float** buffer = (float**)0x1001C000;
10    float *output_hidden_state = (void*)0x10020000;
11    GruParameter* param = (GruParameter*)0x10024000;
12    int hidden_state_batch = 1;
13    int num_directions = 1;
14    if (bidirectional) {
15        hidden_state_batch = hidden_state_batch * 2;
16        num_directions = num_directions * 2;
17    }
18    int input_col_align = hidden_size;
19    int state_col_align = hidden_size;
20    memcpy(output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
21    memcpy(check_output_hidden_state, hidden_state, hidden_state_batch * batch_size * hidden_size * sizeof(float));
22    buffer[0] = (void*)0x10030000;
23    buffer[1] = (void*)0x10034000;
24    buffer[2] = (void*)0x10038000;
25    buffer[3] = (void*)0x1003C000;
26    param->batch_ = batch_size;
27    param->bidirectional_ = bidirectional;
28    param->hidden_size_ = hidden_size;
29    param->input_col_align_ = input_col_align;
30    param->input_size_ = input_size;
31    param->output_step_ = batch_size * hidden_size * num_directions;
32    param->seq_len_ = seq_len;
33    param->state_col_align_ = state_col_align;
34    param->check_seq_len_ = check_seq_len;
35    fp_Gru_p(output, input, weight_g, weight_r, input_bias, state_bias, output_hidden_state, buffer, param, core_mask);
36}
37
// Example entry point: launch the private-memory GRU test on one core.
// NOTE: the original listing declared `void main()` yet executed `return 0;`
// (returning a value from a void function); hosted C requires `int main`.
int main(void) {
    int check_seq_len = 2;
    int seq_len = 2;
    int batch_size = 2;
    int input_size = 2;
    int bidirectional = 0;
    int hidden_size = 2;
    int core_mask = 0b0001; // private-memory version must launch on exactly one core
    TestGruL2Fp32(check_seq_len, seq_len, batch_size, input_size, bidirectional, hidden_size, core_mask);
    return 0;
}